/*
 * Copyright (c) Kumo Inc. and affiliates.
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <melon/cancellation_token.h>
#include <melon/optional.h>
#include <melon/synchronization/detail/sleeper.h>

#include <turbo/log/logging.h>

#include <algorithm>
#include <new>
#include <thread>
#include <tuple>

namespace melon::detail {
    // Sanity checks on destruction (DKCHECK is presumably debug-only):
    //  - no callbacks may still be linked into the list,
    //  - the internal lock must not be held,
    //  - the state word must be below kTokenReferenceCountIncrement, i.e.
    //    presumably no reference counts remain (confirm bit layout in header).
    CancellationState::~CancellationState() {
        [[maybe_unused]] const std::uint64_t finalState =
            state_.load(std::memory_order_relaxed);
        DKCHECK(head_ == nullptr);
        DKCHECK(!isLocked(finalState));
        DKCHECK(finalState < kTokenReferenceCountIncrement);
    }

    // Registers 'callback' so that it is invoked when cancellation is requested.
    //
    // Returns true if the callback was linked into the list. Returns false
    // without registering it when either:
    //  - cancellation was already requested (the callback is then invoked
    //    inline, from inside this call), or
    //  - the state reports it can never be cancelled.
    // When 'incrementRefCountIfSuccessful' is true, a successful registration
    // also takes one token reference, folded into the same atomic operation
    // that releases the lock.
    bool CancellationState::tryAddCallback(
        CancellationCallback *callback,
        bool incrementRefCountIfSuccessful) noexcept {
        // Lock predicate: keep trying to lock only while the observed state is
        // not yet cancelled and is still cancellable.
        auto keepTryingToLock = [callback](std::uint64_t observed) noexcept {
            if (!isCancellationRequested(observed)) {
                return canBeCancelled(observed);
            }
            // Already cancelled: run the callback now instead of queueing it.
            callback->invokeCallback();
            return false;
        };
        if (!tryLock(keepTryingToLock)) {
            return false;
        }

        // Lock held and not cancelled yet: push onto the front of the intrusive
        // list. next_ is the forward link; prevNext_ points at whichever
        // pointer currently references the node (head_ or a predecessor's next_).
        CancellationCallback *const previousHead = head_;
        callback->next_ = previousHead;
        callback->prevNext_ = &head_;
        if (previousHead != nullptr) {
            previousHead->prevNext_ = &callback->next_;
        }
        head_ = callback;

        if (!incrementRefCountIfSuccessful) {
            unlock();
        } else {
            // Release the lock and take a token reference in one atomic op.
            unlockAndIncrementTokenCount();
        }
        return true;
    }

    // Deregisters 'callback', which was previously registered with this state.
    //
    // Handles three situations:
    //  1. The callback is still linked into the list (prevNext_ != nullptr):
    //     unlink it while holding the lock. Then release the lock and drop one
    //     token reference in a single atomic step. That step may 'delete this',
    //     so nothing touches the state afterwards.
    //  2. The callback was already dequeued by requestCancellation() running on
    //     this same thread. Either it has already run, or this call is being
    //     made from inside the callback (the callback deregistering itself).
    //     In the latter case, requestCancellation() is told not to touch the
    //     callback object after invokeCallback() returns.
    //  3. The callback was dequeued by requestCancellation() on another thread:
    //     wait until that thread publishes callbackCompleted_. Once this returns,
    //     the callback is no longer executing.
    // In cases 2 and 3 the token reference is released via
    // removeTokenReference(). That function is defined elsewhere; presumably it
    // may also free this state (confirm in the header).
    void CancellationState::removeCallback(
        CancellationCallback *callback) noexcept {
        DKCHECK(callback != nullptr);

        lock();

        if (callback->prevNext_ != nullptr) {
            // Still registered in the list => not yet executed.
            // Just remove it from the list.
            *callback->prevNext_ = callback->next_;
            if (callback->next_ != nullptr) {
                callback->next_->prevNext_ = callback->prevNext_;
            }

            // Unlock and drop this callback's token reference atomically. This
            // may delete 'this', so return immediately afterwards.
            unlockAndDecrementTokenCount();
            return;
        }

        unlock();

        // Callback has either already executed or is executing concurrently on
        // another thread.

        // signallingThreadId_ is written by requestCancellation() while it holds
        // the lock, before it clears prevNext_. The lock()/unlock() pair above
        // (acquire/release) therefore makes that write visible here.
        if (signallingThreadId_ == std::this_thread::get_id()) {
            // Callback executed on this thread or is still currently executing
            // and is deregistering itself from within the callback.
            if (callback->destructorHasRunInsideCallback_ != nullptr) {
                // Currently inside the callback, let the requestCancellation() method
                // know the object is about to be destructed and that it should
                // not try to access the object when the callback returns.
                *callback->destructorHasRunInsideCallback_ = true;
            }
        } else {
            // Callback is currently executing on another thread, block until it
            // finishes executing.
            melon::detail::Sleeper sleeper;
            // Acquire pairs with the release store in requestCancellation().
            while (!callback->callbackCompleted_.load(std::memory_order_acquire)) {
                sleeper.wait();
            }
        }

        removeTokenReference();
    }

    // Marks this state as cancelled and invokes every registered callback.
    //
    // Returns true if cancellation had already been requested before this call;
    // nothing else is done in that case. Returns false if this call set the
    // cancellation-requested flag and ran the callbacks.
    //
    // Callbacks are popped from the head of the list one at a time under the
    // lock. Each one is invoked with the lock released, so other threads can
    // still deregister their own callbacks in the meantime.
    bool CancellationState::requestCancellation() noexcept {
        if (!tryLockAndCancelUnlessCancelled()) {
            // Was already marked as cancelled
            return true;
        }

        // This thread marked as cancelled and acquired the lock

        // Recorded so that removeCallback() can tell whether it is running on the
        // thread that is invoking callbacks (e.g. from inside a callback).
        signallingThreadId_ = std::this_thread::get_id();

        while (head_ != nullptr) {
            // Dequeue the first item on the queue.
            CancellationCallback *callback = head_;
            head_ = callback->next_;
            const bool anyMore = head_ != nullptr;
            if (anyMore) {
                // The new head's back-link must now point at head_ itself.
                head_->prevNext_ = &head_;
            }
            // Mark this item as removed from the list.
            callback->prevNext_ = nullptr;

            // Don't hold the lock while executing the callback
            // as we don't want to block other threads from
            // deregistering callbacks.
            unlock();

            // TRICKY: Need to store a flag on the stack here that the callback
            // can use to signal that the destructor was executed inline
            // during the call.
            // If the destructor was executed inline then it's not safe to
            // dereference 'callback' after 'invokeCallback()' returns.
            // If the destructor runs on some other thread then the other
            // thread will block waiting for this thread to signal that the
            // callback has finished executing.
            bool destructorHasRunInsideCallback = false;
            callback->destructorHasRunInsideCallback_ = &destructorHasRunInsideCallback;

            callback->invokeCallback();

            if (!destructorHasRunInsideCallback) {
                // The callback object is still alive. Clear the pointer to the
                // stack flag above (it is about to go out of scope), then publish
                // completion to any removeCallback() waiting on another thread.
                callback->destructorHasRunInsideCallback_ = nullptr;
                callback->callbackCompleted_.store(true, std::memory_order_release);
            }

            if (!anyMore) {
                // This was the last item in the queue when we dequeued it.
                // No more items should be added to the queue after we have
                // marked the state as cancelled, only removed from the queue.
                // Avoid acquiring/releasing the lock in this case.
                return false;
            }

            lock();
        }

        unlock();

        return false;
    }

    void CancellationState::lock() noexcept {
        melon::detail::Sleeper sleeper;
        std::uint64_t oldState = state_.load(std::memory_order_relaxed);
        do {
            while (isLocked(oldState)) {
                sleeper.wait();
                oldState = state_.load(std::memory_order_relaxed);
            }
        } while (!state_.compare_exchange_weak(
            oldState,
            oldState | kLockedFlag,
            std::memory_order_acquire,
            std::memory_order_relaxed));
    }

    // Clears the lock bit taken by lock()/tryLock()/
    // tryLockAndCancelUnlessCancelled(). Release ordering publishes every write
    // made while the lock was held to the next thread that acquires it.
    void CancellationState::unlock() noexcept {
        static_cast<void>(
            state_.fetch_sub(kLockedFlag, std::memory_order_release));
    }

    // Releases the lock and adds one token reference in a single atomic
    // operation. Subtracting (kLockedFlag - kTokenReferenceCountIncrement) is,
    // in unsigned (mod 2^64) arithmetic, the same as subtracting kLockedFlag
    // and adding kTokenReferenceCountIncrement.
    void CancellationState::unlockAndIncrementTokenCount() noexcept {
        const std::uint64_t unlockAndAddReference =
            kLockedFlag - kTokenReferenceCountIncrement;
        state_.fetch_sub(unlockAndAddReference, std::memory_order_release);
    }

    // Releases the lock and drops one token reference in a single atomic
    // operation, then deletes this state if that was the last reference.
    //
    // The "last reference" test assumes flag bits sit below
    // kTokenReferenceCountIncrement and any other reference counts sit above
    // the token count (confirm the bit layout in the header). Under that
    // assumption, a previous value below kLockedFlag + 2 * token-increment
    // means only the lock and this single token reference were outstanding.
    void CancellationState::unlockAndDecrementTokenCount() noexcept {
        const std::uint64_t unlockAndDropReference =
            kLockedFlag + kTokenReferenceCountIncrement;
        const std::uint64_t lastReferenceThreshold =
            kLockedFlag + 2 * kTokenReferenceCountIncrement;
        // acq_rel: the acquire half orders the possible 'delete this' after the
        // other threads' releases of their references.
        const std::uint64_t previous =
            state_.fetch_sub(unlockAndDropReference, std::memory_order_acq_rel);
        if (previous < lastReferenceThreshold) {
            delete this;
        }
    }

    bool CancellationState::tryLockAndCancelUnlessCancelled() noexcept {
        melon::detail::Sleeper sleeper;
        std::uint64_t oldState = state_.load(std::memory_order_acquire);
        while (true) {
            if (isCancellationRequested(oldState)) {
                return false;
            } else if (isLocked(oldState)) {
                sleeper.wait();
                oldState = state_.load(std::memory_order_acquire);
            } else if (state_.compare_exchange_weak(
                oldState,
                oldState | kLockedFlag | kCancellationRequestedFlag,
                std::memory_order_acq_rel,
                std::memory_order_acquire)) {
                return true;
            }
        }
    }

    template<typename Predicate>
    bool CancellationState::tryLock(Predicate predicate) noexcept {
        melon::detail::Sleeper sleeper;
        std::uint64_t oldState = state_.load(std::memory_order_acquire);
        while (true) {
            if (!predicate(oldState)) {
                return false;
            } else if (isLocked(oldState)) {
                sleeper.wait();
                oldState = state_.load(std::memory_order_acquire);
            } else if (state_.compare_exchange_weak(
                oldState,
                oldState | kLockedFlag,
                std::memory_order_acquire)) {
                return true;
            }
        }
    }
} // namespace melon::detail
